Move to ChainRulesCore v1.0 (in OptingOut) #1035
Generally the mechanism seems fine. I guess that `args_T` is always just concrete types. I do still have some questions on how the functions work, but it seems fine overall.
test/chainrules.jl
Outdated
# Now try opting out after we have already used it
@opt_out ChainRulesCore.rrule(::typeof(oa_id), x::Real)
oa_id_rrule_hitcount[] = 0
oa_id_outer(x) = sum(oa_id(x))
Why do we redefine the function here?
Probably a mistake
end

do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m)
if do_not_use_rrule
Is this meant to do decomposition?
This function is `has_chain_rrule`, and if it returns false, then `pullback` will indeed end up calling `generate_pullback_via_decomposition`.
Extra fixes for ChainRulesCore @1.0
Minimal failing example for the remaining test failure:
using Zygote
W = ones(Float32, 3)
ps = Zygote.Params([W, ])
gs = gradient(ps) do
    p, pb = pullback(ps) do
        sum(W)
    end
    g = pb(p)
    sum(g[W])
end
(stacktrace)
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] Pullback
@ ./iddict.jl:87 [inlined]
[3] (::typeof(∂(get)))(Δ::Nothing)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
[4] Pullback
@ ~/JuliaEnvs/Zygote.jl/src/lib/lib.jl:68 [inlined]
[5] (::typeof(∂(accum_global)))(Δ::Nothing)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
[6] Pullback
@ ~/JuliaEnvs/Zygote.jl/src/lib/lib.jl:79 [inlined]
[7] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[9] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
[10] getindex
@ ./tuple.jl:29 [inlined]
[11] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
[12] Pullback
@ ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:348 [inlined]
[13] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
[14] Pullback
@ ./REPL[5]:5 [inlined]
[15] (::typeof(∂(#5)))(Δ::Float32)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
[16] (::Zygote.var"#90#91"{Params, typeof(∂(#5)), Zygote.Context})(Δ::Float32)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:348
[17] gradient(f::Function, args::Params)
@ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:76
[18] top-level scope
@ REPL[5]:1
where the failing expression is Removing
That also fails for me with the current released version of Zygote.
Check with #823?
Better MWE that works on the current release, but fails on this PR:
using ChainRulesCore
using Test
using Zygote
@testset "Params nesting" begin
    struct Dense{F,T,S}
        W::T
        b::S
        σ::F
    end
    (d::Dense)(x) = d.σ.(d.W * x .+ d.b)
    d = Dense(ones(Float32, 3,3), zeros(Float32, 3), identity)
    ps = Zygote.Params([d.W, d.b])
    r = ones(Float32, 3,3)
    gs = gradient(ps) do
        p, pb = pullback(ps) do
            sum(d(r))
        end
        g = pb(p)
        sum(g[d.W]) # + sum(g[d.b])
    end
end
That particular issue is fixed by JuliaDiff/ChainRulesCore.jl#414, but other issues still remain.
Probably worth it to merge #823 with a cleanup then.
No, I fixed it to not need that. This is now passing all tests locally. And by JuliaDiff/ForwardDiff.jl#538
@@ -275,7 +304,7 @@ end
 ZygoteRuleConfig(), my_namedtuple, 1., 2., 3.; rrule_f=rrule_via_ad
 )
 test_rrule(
-ZygoteRuleConfig(), my_namedtuple, 1., (2.0, "str"), 3.; rrule_f=rrule_via_ad
+ZygoteRuleConfig(), my_namedtuple, 1., (2.0, 2.4), 3.; rrule_f=rrule_via_ad
Had to change this due to an issue in ChainRulesTestUtils: JuliaDiff/ChainRulesTestUtils.jl#194
The Zygote and ChainRules code handles the string fine, but ChainRulesTestUtils kinda freaks out about it.
Maybe we can keep it as a `gradtest` then? I'm not super sure if CRTU is expected to be a dependency for every AD test suite.
I am pretty sure `gradtest` can't handle it either -- `gradtest` can't handle namedtuples containing tuples.
If you think it is important to have, to test what `rrule_via_ad` does when confronted with a string, I can add a test specifically for that.
> I'm not super sure if CRTU is expected to be a dependency for every AD test suite.
It is, if you have `rrule_via_ad` overloaded (or `rrule`). Then you can use it to test all things.
It's a more robust version of `gradtest` that can handle more types, and doesn't e.g. get tricked by antisymmetric errors.
It's good to have coverage over different data types, but I guess it's going to need a future PR. I can already sense some SMILES-like applications making use of strings outside of embeddings.
@@ -18,7 +18,7 @@ using Zygote, Test, LinearAlgebra
 @test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10

 # https://github.com/FluxML/Zygote.jl/issues/705
-@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ im .* exp.(1:3)
+@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
This is because ChainRules takes embedded subspaces seriously.
The derivative with respect to a real array cannot be imaginary.
The thing that #705 was worried about is still fixed.
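To make the point about embedded subspaces concrete, here is a minimal Base-only sketch. `project_to_primal` is a hypothetical helper invented for this illustration; it is not the ChainRulesCore or Zygote API, and only mimics the idea that a tangent for a real primal gets projected back onto the reals.

```julia
# Hypothetical helper, for illustration only (not part of ChainRulesCore or Zygote):
# keep a candidate tangent `dx` in the same space as its primal `x`.
project_to_primal(x::Real, dx::Number) = real(dx)        # real primal => real tangent
project_to_primal(x::Complex, dx::Number) = complex(dx)  # complex primal keeps both parts
project_to_primal(x::AbstractArray, dx::AbstractArray) = map(project_to_primal, x, dx)

xs = [1, 2, 3]
candidate = im .* exp.(xs)                    # purely imaginary value, as in the old test
projected = project_to_primal(xs, candidate)  # only the real part is kept

@assert all(iszero, projected)
@assert project_to_primal(2.0, 3.0 + 4.0im) === 3.0
```

Under this assumption the purely imaginary candidate projects to zeros, which matches the updated expectation `real(im .* exp.(1:3))` in the test above.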
@@ -449,12 +449,12 @@ end
 @test pullback(type_test)[1] == Complex{<:Real}

 @testset "Pairs" begin
-@test (x->10*pairs((a=x, b=2))[1])'(100) === 10
+@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
This is because we now take embedded subspaces seriously.
Integers are always considered to be a subspace of floats, not just when it happens by coincidence.
Shouldn't this be compiling to integer code anyway?
Integers are not a good type to use to represent tangents, because if you are going to do gradient descent you are going to apply a learning rate like `0.1*dx`.
So we call `float` on them to get the corresponding floating point type.
(Unless it is an index or something, then it would be `NoTangent()`.)
Arguably we could keep integers here, but encouraging people to use integers to represent continuous values that just so happen to be integers feels like not taking subspace types seriously 🤷
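A small Base-only sketch of the argument above, using made-up values (the variable names are illustrative, not taken from Zygote or ChainRulesCore): as soon as a learning rate is applied, an integer-valued gradient leaves the integers anyway, so converting with `float` up front loses nothing.

```julia
# Illustrative values only; none of these names come from Zygote or ChainRulesCore.
dx_int = 10             # a gradient that happens to be integer-valued
lr = 0.1                # learning rate
x = 100                 # integer primal

step = lr * dx_int      # scaling by the learning rate already promotes to Float64
x_new = x - step        # so the updated value is a Float64 as well

# `float` maps an integer to the corresponding floating-point type up front.
dx = float(dx_int)

@assert dx === 10.0
@assert step isa Float64
@assert x_new isa Float64
```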
I think it's a matter of letting the language take charge of these things. It's possible to work with just ints.
Can this be moved to an issue on ChainRulesCore.jl?
It's not going to be resolved in this PR.
@@ -81,7 +81,7 @@ end
 @test gradient(xs ->sum(xs .^ _pow), [4, -1]) == ([_pow*4^9, -10],)

 @test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,)
-@test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ (-234 + 2im)*log(5 - 7im)
+@test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ real((-234 + 2im)*log(5 - 7im))
Again: the primal was real, so its derivative must be too.
Co-authored-by: Dhairya Gandhi
# rrule: specific, no_rrule: fallback => !matches => do use rrule, as haven't opted out.
# rrule: fallback, no_rrule: specific => IMPOSSIBLE, every no_rrule is identical to some rrule
# rrule: specific, no_rrule: specific => matches => do not use rrule as opted out
# rrule: specific, no_rrule: general => !matches => do use rrule as a more specific rrule takes precedence over a more general opt-out
So it's kind of the complement of `@nograd f(x::SomeType, y::SomeOtherType)`, where we can still get grads for some methods and not for others? Kind of defining more specific rrules and dispatching to them, but doing so manually here, rather than letting Julia do it?
It is like defining a more specific `rrule` where that more specific `rrule` is "Let AD work it out".
But we can't use `rrule_via_ad` for this because you hit a stack overflow.
It is not really like `@nograd` (`@non_differentiable` is like `@nograd`), except that both participate in `rrule` dispatch via specificity.
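The specificity-based matching described in the comment table above can be sketched with a Base-only toy model. Everything here is an assumption made for illustration: `toy_rrule` and `toy_no_rrule` stand in for `ChainRulesCore.rrule` and `ChainRulesCore.no_rrule`, and `arg_sig` / `uses_rule` are hypothetical helpers, not the actual `matching_cr_sig` or `has_chain_rrule` code in this PR.

```julia
# Toy stand-ins for ChainRulesCore.rrule / ChainRulesCore.no_rrule (hypothetical names).
toy_rrule(args...) = nothing       # fallback: no rule defined
toy_no_rrule(args...) = nothing    # fallback: nothing opted out

double(x) = 2x

# A general rule for every Number.
toy_rrule(::typeof(double), x::Number) = (2x, Δ -> (nothing, 2Δ))

# Opt out for Real: the same signature is added to both functions,
# so the "rule" that dispatch finds for a Real simply returns `nothing`.
toy_rrule(::typeof(double), x::Real) = nothing
toy_no_rrule(::typeof(double), x::Real) = nothing

# Argument part of a method signature, dropping the generic function's own type.
arg_sig(m::Method) = Tuple{collect(Base.unwrap_unionall(m.sig).parameters)[2:end]...}

# Use the rule only when the method that dispatch picks for the rule function
# is not mirrored by an identical opt-out method.
function uses_rule(f, args...)
    T = Tuple{typeof(f), map(typeof, args)...}
    return arg_sig(which(toy_rrule, T)) != arg_sig(which(toy_no_rrule, T))
end

@assert !uses_rule(double, 1.0)        # specific vs specific: opted out
@assert uses_rule(double, 1.0 + 2im)   # specific rule vs fallback opt-out: use the rule
@assert !uses_rule(sin, 1.0)           # fallback vs fallback: no rule to use
```

In this toy model the `Real` opt-out wins for `Float64` input, while a `Complex` input still reaches the more general `Number` rule, because both functions are resolved by ordinary method specificity.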
This PR will make the changes needed to support ChainRulesCore v1.0.
A number of commits will be added to it to fix each breaking change.
Right now I am aware of 2 such changes:
`@opt_out` to say not to do AD. For "Add opting out of rules" JuliaDiff/ChainRulesCore.jl#398 (this is the biggest change)